# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

"""
This module provides helper functions used by code that the preprocessor
generates. The preprocessor reads the Python AST of the input code and
rewrites it.
"""

from typing import Callable, Iterator, Optional, overload
from typing_extensions import deprecated
import warnings

from .utils.logger import log
from .common import *

from ._mlir_helpers.arith import ArithValue

class Executor:
    """
    The Executor class handles dynamic and compile-time (constexpr) execution
    of "for" loops, "while" loops and "if-else-elif" statements.

    The ``*_execute`` methods do not build IR themselves; they forward to
    callbacks that must first be installed with ``set_functions``.

    Methods:
        set_functions: Installs the callbacks used to classify dynamic
                       expressions and to generate loops, branches and
                       comparisons.

        for_constexpr: Runs a loop body eagerly in Python over a constant range.
        for_execute: Generates an MLIR "for" op via the installed callback.
        while_execute: Generates an MLIR "while" op via the installed callback.
        if_execute: Generates an MLIR "if" op via the installed callback.
    """

    def __init__(self):
        # Every callback starts unset; set_functions() installs them.
        self._is_dynamic_expression = None
        self._loop_execute_range_dynamic = None
        self._if_dynamic = None
        self._while_dynamic = None
        self._compare_executor = None
        self._any_executor = None
        self._all_executor = None

    def set_functions(
        self,
        is_dynamic_expression: Callable,
        loop_execute_range_dynamic: Callable,
        if_dynamic: Callable,
        while_dynamic: Callable,
        compare_executor: Callable,
        any_executor: Optional[Callable] = None,
        all_executor: Optional[Callable] = None,
    ):
        """Install the callbacks this executor dispatches to.

        :param is_dynamic_expression: Predicate telling whether a value is a
            dynamic (IR) expression rather than a Python value.
        :param loop_execute_range_dynamic: Callback invoked by ``for_execute``.
        :param if_dynamic: Callback invoked by ``if_execute``.
        :param while_dynamic: Callback invoked by ``while_execute``.
        :param compare_executor: Callback evaluating a comparison chain
            ``(left, comparators, ops)``.
        :param any_executor: Optional dynamic implementation of ``any``.
        :param all_executor: Optional dynamic implementation of ``all``.
        """
        self._is_dynamic_expression = is_dynamic_expression
        self._loop_execute_range_dynamic = loop_execute_range_dynamic
        self._if_dynamic = if_dynamic
        self._while_dynamic = while_dynamic
        self._compare_executor = compare_executor
        self._any_executor = any_executor
        self._all_executor = all_executor

    @staticmethod
    def convert_to_list(x):
        """This function is used to convert x to a list.
        If x is None, return an empty list.
        If x is not a list, return a list containing x.
        Otherwise, return x itself.
        """
        if x is None:
            return []
        if not isinstance(x, list):
            return [x]
        return x

    @staticmethod
    def converge_ret_val(res):
        """This function is used to converge res (the return value) of the function.
        If res is None, return None.
        If res is a list and has only one element, return the element.
        Otherwise, return res itself.
        """
        if res is None:
            return res
        elif isinstance(res, list) and len(res) == 1:
            return res[0]
        return res

    @staticmethod
    def for_constexpr(
        func: Callable,
        start: int,
        stop: int,
        step: int,
        used_args: list,
        iter_args: list,
    ):
        """Run a loop eagerly in Python, calling ``func`` once per iteration.

        ``func`` is called as ``func(i, *used_args, *loop_results)``. Its
        return value, normalized with ``convert_to_list``, becomes the
        loop-carried values for the next iteration.

        :return: The final loop-carried values, collapsed with
            ``converge_ret_val`` (``iter_args`` itself if the range is empty).
        """
        # This module defines its own ``range`` class whose constructor always
        # raises, so the Python builtin must be referenced explicitly here.
        import builtins

        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)
        loop_results = iter_args
        log().debug("iter_args [%s]", iter_args)
        for i in builtins.range(start, stop, step):
            # Log the values actually fed into this iteration.
            log().debug("i  [%s] iter_args  [%s]", i, loop_results)
            loop_results = Executor.convert_to_list(
                func(i, *used_args, *loop_results)
            )
            log().debug("loop_results  [%s]", loop_results)

        log().debug("done loop_results [%s]", loop_results)
        return Executor.converge_ret_val(loop_results)

    def for_execute(
        self,
        func,
        start,
        stop,
        step,
        used_args=[],
        iter_args=[],
        iter_arg_names=[],
        unroll=-1,
        unroll_full=False,
        pipelining=None,
    ):
        """Forward a dynamic ``for`` loop to the installed loop callback.

        All arguments are passed through unchanged, and the callback's result
        is returned.
        """
        assert (
            self._loop_execute_range_dynamic
        ), "Functions must be set before execution."
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)

        return self._loop_execute_range_dynamic(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
            pipelining,
        )

    def if_execute(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        used_args=[],
        yield_args=[],
        yield_arg_names=[],
    ):
        """Forward an ``if`` statement to the installed ``if_dynamic`` callback."""
        assert self._if_dynamic, "Functions must be set before execution."

        # MLIR generation
        return self._if_dynamic(
            pred, then_block, else_block, used_args, yield_args, yield_arg_names
        )

    def while_execute(
        self,
        pred,
        while_before_block: Callable,
        while_after_block: Callable,
        used_args=[],
        yield_args=[],
        yield_arg_names=[],
    ):
        """Forward a ``while`` loop to the installed ``while_dynamic`` callback.

        Note that ``pred`` is accepted but not forwarded to the callback.
        """
        assert self._while_dynamic, "Functions must be set before execution."

        # MLIR generation. NOTE(review): ``pred`` is not passed on; presumably
        # the condition is evaluated inside ``while_before_block`` — confirm.
        return self._while_dynamic(
            while_before_block,
            while_after_block,
            used_args,
            yield_args,
            yield_arg_names,
        )


# =============================================================================
# Decorator
# =============================================================================

# Module-wide Executor shared by the helper functions below. Its callbacks
# start out as None and are installed later through
# `executor.set_functions(...)`; several helpers assert that this happened.
executor = Executor()


def loop_selector(
    start,
    stop,
    step,
    *,
    used_args=[],
    iter_args=[],
    iter_arg_names=[],
    unroll=-1,
    unroll_full=False,
    pipelining=None,
):
    """Return a callable that emits a dynamic ``for`` loop for a loop body.

    The returned ``ir_loop(func)`` forwards ``func``, the loop bounds and all
    options to ``executor.for_execute`` and returns its result.

    ``Integer`` bounds are unwrapped to their IR values up front. Any other
    value is passed through unchanged.
    """
    log().debug(
        "start [%s] stop [%s] step [%s] used_args [%s] iter_args [%s] unroll [%s] unroll_full [%s] pipelining [%s]",
        start,
        stop,
        step,
        used_args,
        iter_args,
        unroll,
        unroll_full,
        pipelining,
    )
    from .typing import Integer

    def _maybe_upcast(value):
        # Only DSL Integer wrappers are unwrapped; everything else is kept.
        if isinstance(value, Integer):
            value = value.ir_value()

        return value

    start = _maybe_upcast(start)
    stop = _maybe_upcast(stop)
    step = _maybe_upcast(step)

    def ir_loop(func):
        return executor.for_execute(
            func,
            start,
            stop,
            step,
            used_args,
            iter_args,
            iter_arg_names,
            unroll,
            unroll_full,
            pipelining,
        )

    return ir_loop


def if_selector(pred, used_args=[], yield_args=[]):
    """Return a callable that invokes an ``if`` body with its arguments.

    A ``Numeric`` predicate is unwrapped to its ``.value`` first. The returned
    callable then calls ``func(pred, *used_args, *yield_args)`` and returns
    the result.
    """
    log().debug("pred [%s] used_args [%s] yield_args [%s]", pred, used_args, yield_args)

    from .typing import Numeric

    unwrapped_pred = pred.value if isinstance(pred, Numeric) else pred

    def invoke_if_body(func):
        return func(unwrapped_pred, *used_args, *yield_args)

    return invoke_if_body


def while_selector(pred, used_args=[], yield_args=[]):
    """Return a callable that calls ``func(pred, *used_args, *yield_args)``.

    ``pred`` is forwarded exactly as given.
    """

    def call_while_body(func):
        call_args = (pred, *used_args, *yield_args)
        return func(*call_args)

    return call_while_body


def while_executor(
    pred,
    while_before_block: Callable,
    while_after_block: Callable,
    used_args=[],
    yield_args=[],
    yield_arg_names=[],
):
    """Forward a ``while`` statement to the module-level ``executor``.

    Every argument is handed to ``Executor.while_execute`` unchanged, and its
    result is returned.
    """
    return executor.while_execute(
        pred,
        while_before_block=while_before_block,
        while_after_block=while_after_block,
        used_args=used_args,
        yield_args=yield_args,
        yield_arg_names=yield_arg_names,
    )


def if_executor(
    pred,
    then_block: Callable,
    else_block: Optional[Callable] = None,
    used_args=[],
    yield_args=[],
    yield_arg_names=[],
):
    """Forward an ``if``/``else`` statement to the module-level ``executor``.

    Every argument is handed to ``Executor.if_execute`` unchanged, and its
    result is returned.
    """
    return executor.if_execute(
        pred,
        then_block=then_block,
        else_block=else_block,
        used_args=used_args,
        yield_args=yield_args,
        yield_arg_names=yield_arg_names,
    )


# =============================================================================
# Range
# =============================================================================


class range:
    """
    A range-like placeholder for dynamic loop iteration in the DSL.

    It mirrors the call signature of Python's built-in ``range`` but is not
    meant to be constructed or iterated at runtime. Its uses are expected to
    be rewritten by the preprocessor, so both ``__new__`` and ``__iter__``
    raise ``DSLRuntimeError`` if they are ever reached.

    Two call forms are accepted (see the overloads), ``range(stop, ...)`` and
    ``range(start, stop, step, ...)``. Both take these loop-optimization
    parameters:

    - unroll: Number of iterations to unroll (0 or 1 = no unrolling)
    - unroll_full: Whether to fully unroll the loop
    - pipelining: Compiler generated pipeline configuration
    """

    @overload
    def __new__(cls, stop, unroll=0, unroll_full=False, pipelining=None): ...

    @overload
    def __new__(
        cls, start, stop, step, unroll=0, unroll_full=False, pipelining=None
    ): ...

    def __new__(cls, *args, **kwargs):
        # Reaching this means the loop was not rewritten by the preprocessor.
        raise DSLRuntimeError("dynamic range should be always preprocessed to IR")

    def __iter__(self) -> Iterator[int]:
        raise DSLRuntimeError("dynamic range should be always preprocessed to IR")


@deprecated(
    "range_dynamic is deprecated and will be removed in the future, please remove it."
)
def range_dynamic(*args, **kwargs):
    """Deprecated placeholder that always raises ``DSLRuntimeError``.

    Calls are expected to be rewritten by the preprocessor, so reaching this
    body means the call was not preprocessed.
    """
    message = "range_dynamic should be always preprocessed to IR"
    raise DSLRuntimeError(message)


def range_constexpr(*args):
    """Placeholder for a compile-time ``range``; always raises ``DSLRuntimeError``.

    Calls are expected to be rewritten by the preprocessor, so this body
    should never run.
    """
    message = "range_constexpr should be preprocessed by preprocessor."
    raise DSLRuntimeError(message)

# =============================================================================
# If expressions
# =============================================================================


def const_expr(expression):
    """
    Check that ``expression`` is a compile-time (Python) value and return it.

    A ``Numeric`` wrapping a Python ``int``, ``float`` or ``bool`` is unwrapped
    to that Python value. Any other value that the installed
    ``is_dynamic_expression`` callback does not flag as dynamic is returned
    unchanged.

    :raises DSLRuntimeError: If ``expression`` is a dynamic (non compile-time
        constant) expression.
    """
    from .typing import Numeric

    if isinstance(expression, Numeric):
        python_value = expression.value
        if isinstance(python_value, (int, float, bool)):
            return python_value
        is_constant = False
    else:
        is_constant = not executor._is_dynamic_expression(expression)

    if is_constant:
        return expression

    raise DSLRuntimeError(
        f"The function `const_expr({expression})` received a dynamic expression (non compile-time constant).",
        context={
            "If your expression depends on dynamic values": "Remove `const_expr()`",
        },
    )


@deprecated(
    "dynamic_expr is deprecated and will be removed in the future, please remove it."
)
def dynamic_expr(expression):
    """Deprecated no-op that returns ``expression`` unchanged.

    The value is not inspected. The only effect beyond returning it is the
    decorator's deprecation notice.
    """
    return expression


# =============================================================================
# Assertion & casting
# =============================================================================


def assert_executor(test, msg=None):
    """Execute an ``assert`` whose condition must be a compile-time value.

    :param test: The asserted condition. Python values are checked with a
        plain ``assert``. A dynamic ``Numeric`` is first converted with
        ``.to(bool)``. Any other dynamic expression is rejected.
    :param msg: Optional message forwarded to the ``assert``.
    :raises AssertionError: If the (possibly converted) condition is falsy.
        Like any ``assert``, this check is skipped under ``python -O``.
    :raises DSLRuntimeError: If ``test`` is a dynamic expression that cannot
        be converted to ``bool``.
    """
    from .typing import Numeric

    fail = False
    # Implicitly converting a dynamic expression to bool is not allowed, so use
    # an explicit `is not None` check instead of relying on truthiness.
    if test is not None and executor._is_dynamic_expression(test):
        if isinstance(test, Numeric):
            try:
                test = test.to(bool)
            except Exception:
                # Catch only ordinary errors: KeyboardInterrupt / SystemExit
                # must not be reported as a non-constexpr expression.
                fail = True
        else:
            fail = True

    if fail:
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
            suggestion="Please replace with runtime assert.",
        )
    assert test, msg


def bool_cast(value):
    """Convert a compile-time value to a Python ``bool``.

    :param value: The value to convert.
    :return: ``bool(value)``.
    :raises DSLRuntimeError: If ``value`` is a dynamic (IR) expression. Such
        values must be turned into a boolean explicitly, e.g. by a comparison.
    """
    if executor._is_dynamic_expression(value):
        raise DSLRuntimeError(
            "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
            suggestion="Please explicitly convert to boolean with expressions like comparison.",
        )
    return bool(value)

def compare_executor(left, comparators, ops):
    """
    Evaluate a comparison chain through the installed comparison callback.

    Args:
        left: The leftmost value in the comparison chain
        comparators: A list of values to compare against
        ops: A list of comparison operators to apply

    Returns:
        Whatever the installed callback returns for the chain

    Raises:
        AssertionError: If the executor function is not set before execution
    """
    compare_callback = executor._compare_executor
    assert compare_callback is not None, "Function must be set before execution."
    return compare_callback(left, comparators, ops)


def any_executor(iterable):
    """Evaluate ``any`` over ``iterable`` statically or dynamically.

    The installed dynamic ``any`` callback is used only when one is installed
    and ``iterable`` is a dynamic expression. Otherwise the Python builtin
    ``any`` is applied.

    :param iterable: An iterable to check if any elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean
    """
    dynamic_any = executor._any_executor
    if not dynamic_any or not executor._is_dynamic_expression(iterable):
        return any(iterable)
    return dynamic_any(iterable)


def all_executor(iterable):
    """Evaluate ``all`` over ``iterable`` statically or dynamically.

    The installed dynamic ``all`` callback is used only when one is installed
    and ``iterable`` is a dynamic expression. Otherwise the Python builtin
    ``all`` is applied.

    :param iterable: An iterable to check if all elements evaluate to True
    :type iterable: Iterable
    :return: boolean of Python value or IR value
    :rtype: bool or cutlass.Boolean
    """
    dynamic_all = executor._all_executor
    if not dynamic_all or not executor._is_dynamic_expression(iterable):
        return all(iterable)
    return dynamic_all(iterable)


# =============================================================================
# Control flow checks
# =============================================================================
def range_value_check(*args):
    """
    Ensure all `range_constexpr` bounds are compile-time constants (Python ints).
    """
    try:
        args = tuple(arg.__index__() for arg in args)

        # Compute range size and warn if it's too large
        start = 0
        end = 0
        step = 1
        if len(args) == 1:
            end = args[0]
        elif len(args) == 2:
            start = args[0]
            end = args[1]
        elif len(args) == 3:
            start = args[0]
            end = args[1]
            step = args[2]

        range_length = (abs(end - start) - 1) // abs(step) + 1
        if range_length >= 64:
            warnings.warn(
                f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
                category=UserWarning,
                stacklevel=2,
            )

        return (start, end, step)
    except:
        raise DSLRuntimeError(
            "`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
            suggestion="Use `range` instead of `range_constexpr`.",
        )


def range_perf_warning(filename, lineno, *args):
    """Warn that a loop with only static bounds is no longer unrolled.

    Nothing is emitted if any of ``args`` is a dynamic expression. Otherwise
    a ``UserWarning`` is reported at ``filename``:``lineno`` through
    ``warnings.warn_explicit``.
    """
    if any(executor._is_dynamic_expression(bound) for bound in args):
        return

    warnings.warn_explicit(
        (
            "This loop is no longer unrolled and may cause performance regression. "
            "Use `range(..., unroll_full=True)` for full unrolling, or switch to `range_constexpr` when bounds are compile-time constants."
        ),
        category=UserWarning,
        filename=filename,
        lineno=lineno,
    )
